package sqlcv1

import (
	"context"
	"fmt"

	"github.com/jackc/pgx/v5/pgtype"
)

// CutoverPayloadToInsert describes a single payload row to be written by
// InsertCutOverPayloadsIntoTempTable. TenantID, ID, InsertedAt, ExternalID and
// ExternalLocationKey are written to the tenant_id, id, inserted_at,
// external_id and external_location_key columns respectively; Type is written
// to the type column after a cast to v1_payload_type. The row's location is
// always recorded as V1PayloadLocationEXTERNAL by that function, so
// ExternalLocationKey is expected to point at the externally stored payload
// (NOTE(review): key format is defined by the external store — confirm).
type CutoverPayloadToInsert struct {
	TenantID            pgtype.UUID
	ID                  int64
	InsertedAt          pgtype.Timestamptz
	ExternalID          pgtype.UUID
	Type                V1PayloadType
	ExternalLocationKey string
}

// InsertCutOverPayloadsIntoTempTable bulk-inserts payloads into the table named
// by tableName and returns the number of rows that were actually inserted.
//
// Every row is written with location = V1PayloadLocationEXTERNAL, a NULL
// inline_content and updated_at = NOW(). Rows that collide on
// (tenant_id, id, inserted_at, type) are silently skipped via
// ON CONFLICT DO NOTHING, so the returned count may be smaller than
// len(payloads).
//
// tableName is interpolated verbatim into the SQL text (it is not quoted or
// escaped) and must therefore come from a trusted source, never user input.
//
// When payloads is empty no query is issued and (0, nil) is returned. On error
// the returned count is 0.
func InsertCutOverPayloadsIntoTempTable(ctx context.Context, tx DBTX, tableName string, payloads []CutoverPayloadToInsert) (int64, error) {
	if len(payloads) == 0 {
		// Nothing to insert: the query would insert zero rows, so skip the round trip.
		return 0, nil
	}

	n := len(payloads)
	tenantIDs := make([]pgtype.UUID, 0, n)
	ids := make([]int64, 0, n)
	insertedAts := make([]pgtype.Timestamptz, 0, n)
	externalIDs := make([]pgtype.UUID, 0, n)
	types := make([]string, 0, n)
	locations := make([]string, 0, n)
	externalLocationKeys := make([]string, 0, n)

	// Every row gets the same location; convert it to its string form once.
	externalLocation := string(V1PayloadLocationEXTERNAL)

	for _, payload := range payloads {
		tenantIDs = append(tenantIDs, payload.TenantID)
		ids = append(ids, payload.ID)
		insertedAts = append(insertedAts, payload.InsertedAt)
		externalIDs = append(externalIDs, payload.ExternalID)
		types = append(types, string(payload.Type))
		locations = append(locations, externalLocation)
		externalLocationKeys = append(externalLocationKeys, payload.ExternalLocationKey)
	}

	query := fmt.Sprintf(
		// we unfortunately need to use `INSERT INTO` instead of `COPY` here
		// because we can't have conflict resolution with `COPY`.
		`
				WITH inputs AS (
					SELECT
						UNNEST($1::UUID[]) AS tenant_id,
						UNNEST($2::BIGINT[]) AS id,
						UNNEST($3::TIMESTAMPTZ[]) AS inserted_at,
						UNNEST($4::UUID[]) AS external_id,
						UNNEST($5::TEXT[]) AS type,
						UNNEST($6::TEXT[]) AS location,
						UNNEST($7::TEXT[]) AS external_location_key
				), inserts AS (
					INSERT INTO %s (tenant_id, id, inserted_at, external_id, type, location, external_location_key, inline_content, updated_at)
					SELECT
						tenant_id,
						id,
						inserted_at,
						external_id,
						type::v1_payload_type,
						location::v1_payload_location,
						external_location_key,
						NULL,
						NOW()
					FROM inputs
					ORDER BY tenant_id, inserted_at, id, type
					ON CONFLICT(tenant_id, id, inserted_at, type) DO NOTHING
					RETURNING *
				)

				SELECT COUNT(*)
				FROM inserts
				`,
		tableName,
	)

	var insertedCount int64
	err := tx.QueryRow(
		ctx,
		query,
		tenantIDs,
		ids,
		insertedAts,
		externalIDs,
		types,
		locations,
		externalLocationKeys,
	).Scan(&insertedCount)
	if err != nil {
		return 0, fmt.Errorf("inserting %d cutover payloads into %s: %w", n, tableName, err)
	}

	return insertedCount, nil
}

// ComparePartitionRowCounts counts the rows of tempPartitionName and
// sourcePartitionName in a single query and reports whether the two counts are
// equal. Both names are interpolated verbatim into the SQL text, so they must
// come from a trusted source. If the query or scan fails, it returns false
// together with the error.
func ComparePartitionRowCounts(ctx context.Context, tx DBTX, tempPartitionName, sourcePartitionName string) (bool, error) {
	countQuery := fmt.Sprintf(`
				SELECT
					(SELECT COUNT(*) FROM %s) AS temp_partition_count,
					(SELECT COUNT(*) FROM %s) AS source_partition_count
			`, tempPartitionName, sourcePartitionName)

	var tempRows, sourceRows int64
	if scanErr := tx.QueryRow(ctx, countQuery).Scan(&tempRows, &sourceRows); scanErr != nil {
		return false, scanErr
	}

	matches := tempRows == sourceRows
	return matches, nil
}

// findV1PayloadPartitionsBeforeDate lists the child partitions of v1_payload
// whose lower range bound is on or before the date passed as $1, ordered by
// partition name.
//
// The lower and upper bounds are extracted with a regex from the partition
// bound expression returned by pg_get_expr (the text of the form
// FROM ('...') TO ('...')). A partition whose expression does not match, such
// as a DEFAULT partition, yields a NULL lower_bound. That row is then dropped
// by the WHERE clause, because the comparison evaluates to NULL.
//
// The ORDER BY is on the outer SELECT rather than inside the CTE. PostgreSQL
// does not guarantee that ordering applied inside a WITH query is preserved by
// the query that reads from it.
const findV1PayloadPartitionsBeforeDate = `-- name: findV1PayloadPartitionsBeforeDate :many
WITH partitions AS (
    SELECT
        child.relname::text AS partition_name,
        SUBSTRING(pg_get_expr(child.relpartbound, child.oid) FROM 'FROM \(''([^'']+)')::DATE AS lower_bound,
        SUBSTRING(pg_get_expr(child.relpartbound, child.oid) FROM 'TO \(''([^'']+)')::DATE AS upper_bound
    FROM pg_inherits
    JOIN pg_class parent ON pg_inherits.inhparent = parent.oid
    JOIN pg_class child ON pg_inherits.inhrelid = child.oid
    WHERE parent.relname = 'v1_payload'
)

SELECT partition_name, lower_bound AS partition_date
FROM partitions
WHERE lower_bound <= $1::DATE
ORDER BY partition_name
`

// FindV1PayloadPartitionsBeforeDateRow is one result row of
// FindV1PayloadPartitionsBeforeDate. PartitionName is the partition's table
// name (pg_class.relname). PartitionDate is the partition's lower range bound,
// parsed from its bound expression and cast to DATE.
type FindV1PayloadPartitionsBeforeDateRow struct {
	PartitionName string      `json:"partition_name"`
	PartitionDate pgtype.Date `json:"partition_date"`
}

// FindV1PayloadPartitionsBeforeDate runs the findV1PayloadPartitionsBeforeDate
// query against db with date bound to $1. It returns one row per matching
// v1_payload partition, in the order the rows are delivered by the database.
// When nothing matches, the result is a nil slice and a nil error. Any query,
// scan or iteration error aborts the call and is returned unchanged.
func (q *Queries) FindV1PayloadPartitionsBeforeDate(ctx context.Context, db DBTX, date pgtype.Date) ([]*FindV1PayloadPartitionsBeforeDateRow, error) {
	rows, queryErr := db.Query(ctx, findV1PayloadPartitionsBeforeDate, date)
	if queryErr != nil {
		return nil, queryErr
	}
	defer rows.Close()

	var partitions []*FindV1PayloadPartitionsBeforeDateRow
	for rows.Next() {
		partition := new(FindV1PayloadPartitionsBeforeDateRow)
		if scanErr := rows.Scan(&partition.PartitionName, &partition.PartitionDate); scanErr != nil {
			return nil, scanErr
		}
		partitions = append(partitions, partition)
	}

	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}

	return partitions, nil
}
